# SPDX-License-Identifier: Apache-2.0
# Standard
from typing import TYPE_CHECKING, Optional, Union
import threading
import time

# Third Party
import msgspec
import torch
import zmq

# First Party
from lmcache.config import LMCacheEngineMetadata
from lmcache.logging import init_logger
from lmcache.v1.cache_engine import LMCacheEngine
from lmcache.v1.config import LMCacheEngineConfig
from lmcache.v1.lookup_client.abstract_client import LookupClientInterface
from lmcache.v1.lookup_client.async_lookup_message import (
    LookupCleanupMsg,
    LookupRequestMsg,
    LookupResponseMsg,
)
from lmcache.v1.rpc_utils import (
    get_zmq_context,
    get_zmq_rpc_path_lmcache,
    get_zmq_socket,
)

if TYPE_CHECKING:
    # Third Party
    from vllm.config import VllmConfig

logger = init_logger(__name__)


# NOTE(Jiayi): Prefetch could load extra redundant cache if multiple
# workers has different hit tokens.
class LMCacheAsyncLookupClient(LookupClientInterface):
    """
    ZMQ-based lookup client that communicates with a lookup server.

    Related extra_config:
    - lookup_server_worker_ids:
        is a config to control create lookup server on some workers.
        if mla is not enabled, default is [];
        if mla is enabled, default is [0];
        - if lookup_server_worker_ids is [], start lookup server on all workers
        - if lookup_server_worker_ids is [0], start lookup server on worker0
        - if lookup_server_worker_ids is [0, 3, 6], start lookup server on
          worker0, worker3 and worker6
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        config: LMCacheEngineConfig,
        metadata: LMCacheEngineMetadata,
    ):
        # lookup_id -> first lookup time
        # this helps us support timeout semantics
        self.first_lookup_time: dict[str, float] = {}
        self.timeout_ms = config.lookup_timeout_ms

        self.ctx = get_zmq_context(use_asyncio=False)
        rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
            "lmcache_rpc_port", 0
        )
        self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
        self.tensor_parallel_size = vllm_config.parallel_config.tensor_parallel_size
        self.num_ranks = self.tensor_parallel_size * self.pipeline_parallel_size
        self.lookup_server_worker_ids = config.get_lookup_server_worker_ids(
            metadata.use_mla, metadata.world_size
        )

        self.push_sockets = []
        if len(self.lookup_server_worker_ids) > 0:
            ranks = self.lookup_server_worker_ids
            self.num_ranks = len(self.lookup_server_worker_ids)
        else:
            ranks = [i for i in range(self.num_ranks)]

        for rank in ranks:
            worker_socket_path = get_zmq_rpc_path_lmcache(
                vllm_config, "lookup_worker", rpc_port, rank
            )
            logger.info(
                f"lmcache lookup client connect to rank {rank} "
                f"with worker socket path {worker_socket_path}"
            )

            push_socket = get_zmq_socket(
                self.ctx,
                worker_socket_path,
                "ipc",
                zmq.PUSH,  # type: ignore[attr-defined]
                "connect",
            )

            self.push_sockets.append(push_socket)

        scheduler_socket_path = get_zmq_rpc_path_lmcache(
            vllm_config, "lookup_scheduler", rpc_port, 0
        )
        self.pull_socket = get_zmq_socket(
            self.ctx,
            scheduler_socket_path,
            "ipc",
            zmq.PULL,  # type: ignore[attr-defined]
            "bind",
        )
        logger.info(
            f"lmcache lookup client connect to scheduler "
            f"with socket path {scheduler_socket_path}"
        )

        # First Party
        from lmcache.v1.token_database import (
            ChunkedTokenDatabase,
            SegmentTokenDatabase,
            TokenDatabase,
        )

        self.token_database: TokenDatabase
        if config.enable_blending:
            self.token_database = SegmentTokenDatabase(config, metadata)
        else:
            self.token_database = ChunkedTokenDatabase(config, metadata)

        # A lock is needed since we need another thread to pull
        # responses from the lookup_and_prefetch server
        # (e.g., worker process).
        self.lock = threading.Lock()

        # map from lookup_id (i.e., req_id) to req's status.
        # None indicates ongoing.
        # int indicates number of hit tokens.
        self.reqs_status: dict[str, Optional[int]] = {}

        # map from lookup_id (i.e., req_id) to number of hit tokens for each worker
        self.res_for_each_worker: dict[str, list[int]] = {}

        # The two parts are [lookup_id (i.e., req_id), num_hit_tokens]
        self.num_parts = 2

        # Track lookup_ids that have been aborted for cleanup
        self.aborted_lookups: set[str] = set()

        self.running = True

        self.thread = threading.Thread(
            target=self.process_responses_from_workers, daemon=True
        )
        self.thread.start()

        # default backoff time
        self.lookup_backoff_time = 0.01
        if config.extra_config is not None:
            self.lookup_backoff_time = float(
                config.extra_config.get("lookup_backoff_time", self.lookup_backoff_time)
            )

    def lookup_cache(self, lookup_id: str) -> Optional[int]:
        """
        -1 means not found;
        None means ongoing;
        int >= 0 means number of hit tokens
        """
        # Check if any aborted lookups are finished, send cleanup messages
        self._cleanup_finished_aborted_lookups()

        with self.lock:
            if (req_status := self.reqs_status.get(lookup_id, -1)) == -1:
                self.reqs_status[lookup_id] = None
                self.first_lookup_time[lookup_id] = time.time()
            elif req_status is None:
                time.sleep(self.lookup_backoff_time)
                if (
                    time.time() - self.first_lookup_time[lookup_id]
                ) * 1000 > self.timeout_ms:
                    logger.warning(
                        (
                            "Request %s is still waiting for async lookup "
                            "after %d seconds, returning 0 lmcache cached tokens "
                            "so vllm can recompute"
                        ),
                        lookup_id,
                        self.timeout_ms // 1000,
                    )
                    self.first_lookup_time.pop(lookup_id, None)
                    return 0

            return req_status

    # TODO(Jiayi): Consider batching here
    def lookup(
        self,
        token_ids: Union[torch.Tensor, list[int]],
        lookup_id: str,
        request_configs: Optional[dict] = None,
        num_computed_tokens: int = 0,
    ) -> Optional[int]:
        # TODO(ziruiliu) num_computed_tokens is passed in
        # Due to complexity in preemption or eviction in async mode
        # We will handle skipping logic in future.

        hashes: list[int] = []
        offsets = []
        for start, end, hash_val in self.token_database.process_tokens(
            token_ids, make_key=False
        ):
            hashes.append(hash_val)  # type: ignore[arg-type]
            offsets.append(end - start)

        # Create structured message
        msg = LookupRequestMsg(
            lookup_id=lookup_id,
            hashes=hashes,
            offsets=offsets,
            request_configs=request_configs,
        )

        # Serialize message using msgspec
        msg_buf = msgspec.msgpack.encode(msg)

        for i in range(self.num_ranks):
            self.push_sockets[i].send(msg_buf, copy=False)
        time.sleep(self.lookup_backoff_time)
        return None

    def process_responses_from_workers(self):
        while self.running:
            try:
                msg_buf = self.pull_socket.recv(copy=False)
                # Deserialize message using msgspec
                msg = msgspec.msgpack.decode(msg_buf, type=LookupResponseMsg)
                lookup_id = msg.lookup_id
                res = msg.num_hit_tokens

                with self.lock:
                    if lookup_id not in self.res_for_each_worker:
                        self.res_for_each_worker[lookup_id] = [res]
                    else:
                        self.res_for_each_worker[lookup_id].append(res)
                    all_res = self.res_for_each_worker[lookup_id]

                    if len(all_res) == self.num_ranks:
                        self.res_for_each_worker.pop(lookup_id)

                        # NOTE: it is possible that the number of hit
                        # tokens is different across (TP and PP) ranks, so we
                        # can use the minimum value as the number of
                        # hit tokens.
                        self.reqs_status[lookup_id] = min(all_res)

            except Exception as e:
                logger.error(f"Error processing response from worker: {e}")

    def clear_lookup_status(self, lookup_id: str) -> None:
        with self.lock:
            self.reqs_status.pop(lookup_id, None)
            self.first_lookup_time.pop(lookup_id, None)

    def cancel_lookup(self, lookup_id: str) -> None:
        """Mark lookup as aborted. Cleanup will happen after task finishes."""
        self.aborted_lookups.add(lookup_id)

    def _cleanup_finished_aborted_lookups(self) -> None:
        """Check for finished aborted lookups and send cleanup messages to workers."""
        # A lookup whose status is None is still loading.
        # We wait for it to finish before cleanup.
        finished_lookups = [
            lookup_id
            for lookup_id in self.aborted_lookups
            if self.reqs_status.get(lookup_id) is not None
        ]
        if finished_lookups:
            self.aborted_lookups.difference_update(finished_lookups)

        # Tell the server to free the reserved memory buffers for each aborted lookup.
        for lookup_id in finished_lookups:
            self._send_cleanup_message(lookup_id)
            self.clear_lookup_status(lookup_id)

    def _send_cleanup_message(self, lookup_id: str) -> None:
        """Send cleanup message to workers to release memory objects."""
        msg = LookupCleanupMsg(lookup_id=lookup_id)
        msg_buf = msgspec.msgpack.encode(msg)

        for i in range(self.num_ranks):
            self.push_sockets[i].send(msg_buf, copy=False)
        logger.debug("Sent cleanup message for lookup_id=%s", lookup_id)

    def supports_producer_reuse(self) -> bool:
        """Return True as LMCacheLookupClient supports producer kvcache reuse"""
        return True

    def close(self):
        self.running = False
        try:
            if self.thread.is_alive():
                self.thread.join(timeout=1.0)
            for s in self.push_sockets:
                s.close(linger=0)  # type: ignore[arg-type]
            self.pull_socket.close(linger=0)  # type: ignore[arg-type]
            self.ctx.term()
        except Exception as e:
            logger.warning(f"Failed to join thread during close: {e}")


class LMCacheAsyncLookupServer:
    """ZMQ-based async lookup server that handles lookup and prefetch
    requests using LMCacheEngine."""

    def __init__(self, lmcache_engine: LMCacheEngine, vllm_config: "VllmConfig"):
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
            "lmcache_rpc_port", 0
        )
        worker_socket_path = get_zmq_rpc_path_lmcache(
            vllm_config, "lookup_worker", rpc_port, vllm_config.parallel_config.rank
        )
        scheduler_socket_path = get_zmq_rpc_path_lmcache(
            vllm_config, "lookup_scheduler", rpc_port, 0
        )
        self.push_socket = get_zmq_socket(
            self.ctx,
            scheduler_socket_path,
            "ipc",
            zmq.PUSH,  # type: ignore[attr-defined]
            "connect",
        )
        self.pull_socket = get_zmq_socket(
            self.ctx,
            worker_socket_path,
            "ipc",
            zmq.PULL,  # type: ignore[attr-defined]
            "bind",
        )

        self.lmcache_engine = lmcache_engine
        self.running = True

        logger.info(
            "lmcache lookup server start with"
            f" scheduler socket path {scheduler_socket_path}, "
            f"worker socket path {worker_socket_path}"
        )
        self.thread = threading.Thread(
            target=self.process_requests_from_scheduler, daemon=True
        )
        self.thread.start()

    def process_requests_from_scheduler(self):
        while self.running:
            try:
                msg_buf = self.pull_socket.recv(copy=False)
                # rely on msgspec to automatically discriminate
                # between LookupRequestMsg and LookupCleanupMsg
                msg = msgspec.msgpack.decode(
                    msg_buf,
                    type=Union[LookupRequestMsg, LookupCleanupMsg],
                )

                if isinstance(msg, LookupRequestMsg):
                    # Handle lookup request
                    self.lmcache_engine.async_lookup_and_prefetch(
                        lookup_id=msg.lookup_id,
                        hashes=msg.hashes,
                        offsets=msg.offsets,
                        pin=True,
                        request_configs=msg.request_configs,
                    )

                elif isinstance(msg, LookupCleanupMsg):
                    # Handle cleanup request - release memory objects for aborted lookup
                    self.lmcache_engine.cleanup_memory_objs(msg.lookup_id)

                else:
                    logger.warning("Unknown message type: %s", type(msg))

            except Exception as e:
                logger.error("Error processing request from scheduler: %s", e)

    def send_response_to_scheduler(self, lookup_id: str, num_hit_tokens: int):
        # Create structured response message
        msg = LookupResponseMsg(
            lookup_id=lookup_id,
            num_hit_tokens=num_hit_tokens,
        )

        # Serialize message using msgspec
        msg_buf = msgspec.msgpack.encode(msg)
        self.push_socket.send(msg_buf, copy=False)

    def close(self):
        self.running = False
        try:
            if self.thread.is_alive():
                self.thread.join(timeout=1.0)
            for s in self.push_sockets:
                s.close(linger=0)  # type: ignore[arg-type]
            self.pull_socket.close(linger=0)  # type: ignore[arg-type]
            self.ctx.term()
        except Exception as e:
            logger.warning(f"Failed to join thread during close: {e}")
